import argparse
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, models
from text_pair_dataset import TextPairDataset
from tqdm import tqdm
import os

from heads import get_matching_head
from loss_func import get_loss_func
import random
import numpy as np

from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch):
    """Collate ``(token_ids, reason, label)`` samples without padding.

    Args:
        batch: Sequence of 3-field samples ``(token_ids, reason, label)``.

    Returns:
        A tuple ``(token_id_tensors, reasons, labels)``:
        ``token_id_tensors`` is a list holding one ``torch.long`` tensor per
        sample (kept as a list, so per-sample lengths may differ),
        ``reasons`` is a list of the second field of every sample, and
        ``labels`` is a single ``torch.float`` tensor built from the third
        field of every sample.
    """
    # Transpose the batch: one column per sample field.
    ids_column, reasons_column, labels_column = zip(*batch)

    id_tensors = []
    for sample_ids in ids_column:
        id_tensors.append(torch.tensor(sample_ids, dtype=torch.long))

    label_tensor = torch.tensor(labels_column, dtype=torch.float)
    return id_tensors, list(reasons_column), label_tensor


def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus the CUDA seeding
    entry points) with the same value, then sets the cuDNN flags to
    ``deterministic=True`` and ``benchmark=False``.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def build_model(base_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Assemble a ``SentenceTransformer`` from a transformer and a pooling module.

    Args:
        base_model: Model name or path handed to ``models.Transformer``.

    Returns:
        A ``SentenceTransformer`` whose modules are the transformer followed
        by a ``models.Pooling`` layer sized to the transformer's word
        embedding dimension (all other pooling options are left at the
        library defaults).
    """
    # NOTE(review): trust_remote_code=True is passed everywhere below; confirm
    # that every ``base_model`` used with this script comes from a trusted source.
    remote_code_opts = {"trust_remote_code": True}
    transformer = models.Transformer(
        base_model,
        model_args=dict(remote_code_opts),
        config_args=dict(remote_code_opts),
    )
    word_dim = transformer.get_word_embedding_dimension()
    pooler = models.Pooling(word_dim)
    return SentenceTransformer(modules=[transformer, pooler], trust_remote_code=True)

def _mean_token_embedding(token_ids, tokenid2emb, embedding_dim, device):
    """Return the mean of the cached embeddings for the ids in ``token_ids``.

    Ids that have no entry in ``tokenid2emb`` are skipped; when none remain a
    zero vector of size ``embedding_dim`` on ``device`` is returned.
    """
    known = [tokenid2emb[i] for i in token_ids.tolist() if i in tokenid2emb]
    if not known:
        return torch.zeros(embedding_dim, device=device)
    return torch.stack(known, dim=0).mean(dim=0)


def _token_ids_to_text(token_ids, inv_vocab):
    """Map ``token_ids`` to vocabulary strings and join them with spaces.

    Ids missing from ``inv_vocab`` and whitespace-only tokens are dropped.
    """
    toks = (inv_vocab[i] for i in token_ids.tolist() if i in inv_vocab)
    return " ".join(tok for tok in toks if tok.strip())


def train_and_save(args):
    """Train the embedding model and matching head, then save both.

    Training is skipped when ``<save_dir>/embedding_model`` and
    ``<save_dir>/matching_head.pt`` both already exist. Otherwise the models
    are trained for ``args.num_epochs`` epochs; a checkpoint is written to
    ``<save_dir>/epoch-<n>/`` after every epoch and the final weights are
    written to ``<save_dir>/``.

    Args:
        args: Namespace providing ``save_dir``, ``seed``, ``model_name``,
            ``head_type``, ``freeze_embedding_model``, ``json_data_path``,
            ``limit``, ``batch_size``, ``lr``, ``loss_type``, ``num_epochs``
            and ``token_mode`` (``"mean"`` or ``"seq"``).

    Returns:
        None.

    Raises:
        ValueError: If ``args.token_mode`` is neither ``"mean"`` nor ``"seq"``
            (raised when the first batch is processed).
    """
    embedding_dir = os.path.join(args.save_dir, "embedding_model")
    head_path = os.path.join(args.save_dir, "matching_head.pt")
    if os.path.exists(embedding_dir) and os.path.exists(head_path):
        print(f"Model already exists at {args.save_dir}. Skipping training.")
        return
    set_seed(args.seed)

    # Use the GPU when there is one, but do not crash on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embedding_model = build_model(args.model_name).to(device).train()
    embedding_dim = embedding_model.get_sentence_embedding_dimension()
    matching_head = get_matching_head(args.head_type, embedding_dim).to(device).train()

    if args.freeze_embedding_model:
        for param in embedding_model.parameters():
            param.requires_grad = False

    dataset = TextPairDataset(args.json_data_path, model_name=args.model_name, limit=args.limit)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate_fn)

    # Pre-compute one embedding per vocabulary token, skipping tokens that
    # start with "[" and tokens that are empty or whitespace-only.
    tokenizer = embedding_model.tokenizer
    vocab = tokenizer.get_vocab()
    filtered_items = [(tok, idx) for tok, idx in vocab.items() if not tok.startswith("[") and tok.strip()]
    tokens = [tok for tok, _ in filtered_items]
    ids = [idx for _, idx in filtered_items]

    with torch.no_grad():
        token_embs = embedding_model.encode(tokens, convert_to_tensor=True, normalize_embeddings=True, show_progress_bar=True)
    # NOTE: this table is built once without gradients, so in "mean" mode no
    # gradient reaches the embedding model through emb_a and the table is not
    # refreshed while the model is being updated.
    tokenid2emb = {int(i): emb.to(device) for i, emb in zip(ids, token_embs)}

    # SentenceTransformer.encode() puts the model in eval mode, which would
    # silently undo the .train() call above for the whole run. Set the mode
    # explicitly: eval for a frozen encoder, train when it is fine-tuned.
    if args.freeze_embedding_model:
        embedding_model.eval()
    else:
        embedding_model.train()

    if args.freeze_embedding_model:
        optimizer = optim.Adam(matching_head.parameters(), lr=args.lr)
    else:
        optimizer = optim.Adam(list(embedding_model.parameters()) + list(matching_head.parameters()), lr=args.lr)

    loss_fn = get_loss_func(args.loss_type)

    # Id -> token string lookup for "seq" mode; loop-invariant, so build once.
    inv_vocab = {v: k for k, v in vocab.items()}

    for epoch in range(args.num_epochs):
        total_loss = 0.0
        for token_id_lists, reasons, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_epochs}"):
            labels = labels.float().to(device)

            # Sentence embedding of the "reason" text.
            tokenized_b = embedding_model.tokenize(reasons)
            tokenized_b = {k: v.to(device) for k, v in tokenized_b.items()}
            emb_b = embedding_model(tokenized_b)["sentence_embedding"]

            if args.token_mode == "mean":
                # Average the cached per-token embeddings of every sample.
                emb_a = torch.stack(
                    [
                        _mean_token_embedding(token_ids, tokenid2emb, embedding_dim, device)
                        for token_ids in token_id_lists
                    ],
                    dim=0,
                )

            elif args.token_mode == "seq":
                # Turn the ids back into text and run it through the model.
                token_str_seqs = [_token_ids_to_text(token_ids, inv_vocab) for token_ids in token_id_lists]
                tokenized_a = embedding_model.tokenize(token_str_seqs)
                tokenized_a = {k: v.to(device) for k, v in tokenized_a.items()}
                emb_a = embedding_model(tokenized_a)["sentence_embedding"]

            else:
                raise ValueError(f"Unsupported token_mode: {args.token_mode}")

            features = {"embedding_a": emb_a, "embedding_b": emb_b}
            outputs = matching_head(features)
            logits = outputs["logits"].squeeze(-1)

            # Loss and parameter update.
            loss = loss_fn(logits, labels, emb_a=emb_a, emb_b=emb_b)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Guard the average against an empty loader.
        num_batches = len(dataloader)
        avg_loss = total_loss / num_batches if num_batches else 0.0
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

        epoch_save_dir = os.path.join(args.save_dir, f"epoch-{epoch+1}")
        os.makedirs(epoch_save_dir, exist_ok=True)
        embedding_model.save(os.path.join(epoch_save_dir, "embedding_model"))
        torch.save(matching_head.state_dict(), os.path.join(epoch_save_dir, "matching_head.pt"))
        print(f"Models for epoch {epoch+1} saved to {epoch_save_dir}/")

    os.makedirs(args.save_dir, exist_ok=True)
    embedding_model.save(embedding_dir)
    torch.save(matching_head.state_dict(), head_path)
    print(f"Models saved to {args.save_dir}/")

if __name__ == "__main__":
    # Values the CLI accepts for --head_type and --loss_type.
    head_choices = ["base", "deep_mlp", "cos_sim", "residual", "cross_attn", "feature", "cos_sim_deeper"]
    loss_choices = ["bce", "focal", "contrastive", "circle", "weighted_bce", "auc_margin"]

    cli = argparse.ArgumentParser()

    # Model, output directory and training data.
    cli.add_argument("--model_name", type=str, default="sentence-transformers/all-MiniLM-L6-v2")
    cli.add_argument("--save_dir", type=str, default="saved_model_dir_1wdata")
    cli.add_argument(
        "--json_data_path",
        type=str,
        required=True,
        help="Path to the token-reason style JSON file",
    )

    # Optimisation and reproducibility settings.
    cli.add_argument("--batch_size", type=int, default=128)
    cli.add_argument("--lr", type=float, default=2e-5)
    cli.add_argument("--num_epochs", type=int, default=3)
    cli.add_argument("--limit", type=int, default=0)
    cli.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

    # How the token side is embedded, and which head / loss to train with.
    cli.add_argument(
        "--token_mode",
        type=str,
        default="mean",
        choices=["mean", "seq"],
        help="How to compute token-side embedding: 'mean' of token vectors or 'seq' embedding from text sequence",
    )
    cli.add_argument("--head_type", type=str, default="base", choices=head_choices)
    cli.add_argument("--loss_type", type=str, default="bce", choices=loss_choices)

    cli.add_argument(
        "--freeze_embedding_model",
        action="store_true",
        help="If set, freeze the embedding model and only train the matching head.",
    )

    args = cli.parse_args()
    train_and_save(args)
